import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.utils import softmax
from torch_scatter import scatter_add

from Network.network import Network

class GINConv(MessagePassing):
    """GIN-style convolution: aggregate raw neighbour features, then apply an MLP.

    Every edge ``j -> i`` sends the unchanged source features ``x_j``. The
    aggregated result is fed through ``Linear(emb_dim, 2*emb_dim) -> ReLU ->
    Linear(2*emb_dim, out_dim)``. There is no separate self term, so a node's
    own features only contribute when ``edge_index`` contains self-loops.

    Args:
        emb_dim: Size of the incoming node features.
        out_dim: Size of the produced node features.
        aggr: Aggregation scheme handed to ``MessagePassing``.
        **kwargs: Any further ``MessagePassing`` options.
    """

    def __init__(self, emb_dim, out_dim, aggr="add", **kwargs):
        # ``aggr`` is a named parameter, so it can never also be inside kwargs;
        # MessagePassing stores it as ``self.aggr``.
        super(GINConv, self).__init__(aggr=aggr, **kwargs)
        hidden_dim = 2 * emb_dim
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x, edge_index):
        """Run one round of message passing over ``edge_index``."""
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # Identity message: neighbours contribute their features unchanged.
        return x_j

    def update(self, aggr_out):
        # Node update is the MLP applied to the aggregated neighbourhood.
        return self.mlp(aggr_out)


class GCNConv(MessagePassing):
    """GCN-style convolution with symmetric degree normalisation.

    Node features are first projected by ``self.linear``. Each edge message is
    ``deg(row)^-1/2 * deg(col)^-1/2 * (W x_j + edge_attr)``, where degrees are
    counted over ``edge_index[0]``. The ``edge_attr`` term is dropped when it
    is ``None``.

    Args:
        emb_dim: Input and output feature size.
        aggr: Aggregation scheme handed to ``MessagePassing``.
    """

    def __init__(self, emb_dim, aggr="add"):
        # Give the aggregation to MessagePassing itself. Assigning ``self.aggr``
        # after construction is ignored by PyG versions that build their
        # aggregation module inside ``__init__``.
        super(GCNConv, self).__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.linear = torch.nn.Linear(emb_dim, emb_dim)

    def norm(self, edge_index, num_nodes, dtype):
        """Return the per-edge symmetric normalisation coefficients.

        Nodes with zero degree get coefficient 0 instead of ``inf``.
        """
        ### assuming that self-loops have been already added in edge_index
        edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, device=edge_index.device)
        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_attr=None):
        """Project ``x`` and propagate normalised messages over ``edge_index``.

        Args:
            x: Node features, one row per node.
            edge_index: ``[2, E]`` connectivity.
            edge_attr: Optional per-edge features added to each message; they
                must broadcast against the projected ``x_j``.
        """
        norm = self.norm(edge_index, x.size(0), x.dtype)
        x = self.linear(x)
        # Current PyG API: propagate(edge_index, **kwargs). ``edge_attr`` must be
        # forwarded explicitly so ``message`` receives it.
        return self.propagate(edge_index, x=x, edge_attr=edge_attr, norm=norm)

    def message(self, x_j, edge_attr, norm):
        if edge_attr is not None:
            x_j = x_j + edge_attr
        return norm.view(-1, 1) * x_j


class GATConv(MessagePassing):
    """Multi-head graph attention convolution with heads averaged at the end.

    Node features are projected to ``heads * emb_dim`` channels. For each edge
    ``j -> i`` and head ``h``, the attention logit is
    ``leaky_relu(att_h . [x_i || x_j])``. Logits are normalised with a softmax
    over all edges that share the same target node ``i``. The weighted messages
    ``alpha * x_j`` are aggregated, averaged across heads, and shifted by a
    learned bias.

    Args:
        emb_dim: Input feature size and per-head output size.
        heads: Number of attention heads.
        negative_slope: LeakyReLU slope applied to the attention logits.
        aggr: Aggregation scheme handed to ``MessagePassing``.
    """

    def __init__(self, emb_dim, heads=2, negative_slope=0.2, aggr="add"):
        # node_dim=0: messages are 3-D ``[E, heads, emb_dim]`` and must be
        # aggregated along the edge axis, not the heads axis.
        super(GATConv, self).__init__(aggr=aggr, node_dim=0)
        self.heads = heads
        self.emb_dim = emb_dim
        self.negative_slope = negative_slope
        self.weight_linear = nn.Linear(emb_dim, heads * emb_dim)
        self.att = nn.Parameter(torch.Tensor(1, heads, 2 * emb_dim))
        self.bias = nn.Parameter(torch.Tensor(emb_dim))
        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.att)
        zeros(self.bias)

    def forward(self, x, edge_index):
        """Project ``x`` and attend over the full ``[2, E]`` ``edge_index``."""
        x = self.weight_linear(x)
        # Pass the whole connectivity. A single row of it carries no target
        # nodes, so messages cannot be routed.
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j, index, size_i):
        x_i = x_i.view(-1, self.heads, self.emb_dim)
        x_j = x_j.view(-1, self.heads, self.emb_dim)
        alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
        alpha = F.leaky_relu(alpha, self.negative_slope)
        # ``index`` is the target node of each edge, i.e. the group the messages
        # are aggregated into. Normalising per target makes each node's
        # incoming attention weights sum to one.
        alpha = softmax(alpha, index, num_nodes=size_i)
        return x_j * alpha.view(-1, self.heads, 1)

    def update(self, aggr_out):
        # [N, heads, emb_dim] -> [N, emb_dim]: average the heads, then add the bias.
        return aggr_out.mean(dim=1) + self.bias


class GraphSAGEConv(MessagePassing):
    """GraphSAGE-style convolution: project, aggregate neighbours, L2-normalise.

    Node features are projected by ``self.linear``. Neighbour features
    ``x_j`` are aggregated with ``aggr``, and each aggregated row is scaled to
    unit L2 norm.

    Args:
        emb_dim: Input and output feature size.
        aggr: Aggregation scheme handed to ``MessagePassing`` (default ``"mean"``).
    """

    def __init__(self, emb_dim, aggr="mean"):
        # The aggregation must reach MessagePassing. Otherwise PyG versions that
        # resolve it in ``__init__`` fall back to their default of "add".
        super(GraphSAGEConv, self).__init__(aggr=aggr)
        self.emb_dim = emb_dim
        self.linear = torch.nn.Linear(emb_dim, emb_dim)

    def forward(self, x, edge_index):
        """Project ``x`` and propagate over ``edge_index``."""
        x = self.linear(x)
        # Current PyG API: propagate(edge_index, **kwargs).
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        return x_j

    def update(self, aggr_out):
        return F.normalize(aggr_out, p=2, dim=-1)


class GraphNetwork(Network):
    """Graph neural network over linearly encoded query/key objects.

    ``query`` and ``key`` are embedded by separate linear layers, concatenated
    and reshaped into rows of size ``enc_dim``. The result is passed through
    ``args.keypair.num_layer`` GIN layers. Each layer is followed by batch
    norm, ReLU on every layer except the last, and dropout. A final linear
    head ``node_mlp`` maps the features to ``output_dim``. With
    ``JK == "concat"`` those features are the concatenation of the input
    embedding and every layer's output; otherwise they are the last layer's
    output.

    Args:
        args: Config object. Reads ``object_dim``, ``output_dim``,
            ``activation`` and ``keypair.{embed_dim, JK, drop_ratio,
            num_layer, net_type}``.

    Raises:
        ValueError: If GNN layers are requested with a ``keypair.net_type``
            other than ``"gin"``.
    """

    def __init__(self, args):
        super().__init__(args)
        keypair = args.keypair
        self.object_dim = args.object_dim
        self.output_dim = args.output_dim  # slightly different in meaning from num_outputs
        self.is_crelu = args.activation == "crelu"
        self.enc_dim = keypair.embed_dim
        self.msg_passing = keypair
        self.JK = keypair.JK
        self.drop_ratio = keypair.drop_ratio
        self.num_gnn_layers = keypair.num_layer

        # Fail fast. An unsupported net_type would otherwise leave self.model
        # empty and only surface as an IndexError on the first forward pass.
        if self.num_gnn_layers > 0 and keypair.net_type != "gin":
            raise ValueError(
                f"unsupported keypair.net_type {keypair.net_type!r}; only 'gin' is implemented"
            )

        # Registration/creation order is kept as before: model, k_linear,
        # q_linear, batch_norms, node_mlp.
        self.model = torch.nn.ModuleList()
        self.k_linear = nn.Linear(self.object_dim, self.enc_dim)
        self.q_linear = nn.Linear(self.object_dim, self.enc_dim)
        self.model.extend(
            [GINConv(self.enc_dim, self.enc_dim, aggr="add") for _ in range(self.num_gnn_layers)]
        )
        self.batch_norms = torch.nn.ModuleList(
            [torch.nn.BatchNorm1d(self.enc_dim) for _ in range(self.num_gnn_layers)]
        )

        # "concat" jumping knowledge sees the input embedding plus every layer output.
        jk_layers = self.num_gnn_layers + 1 if self.JK == "concat" else 1
        self.node_mlp = torch.nn.Linear(jk_layers * self.enc_dim, self.output_dim)

        self.train()
        self.reset_network_parameters()

    def forward(self, x, edge_index):
        """Encode ``x = (query, key)`` and run the GNN stack over ``edge_index``.

        Returns:
            ``node_mlp`` applied to the final node features (see class docstring).
        """
        query, key = x
        query_enc = self.q_linear(query)
        key_enc = self.k_linear(key)
        # Concatenate along dim 1, then reshape to rows of size enc_dim.
        # NOTE(review): assumes this row order matches the node indices used in
        # edge_index -- confirm against the caller.
        x_enc = torch.cat((query_enc, key_enc), dim=1).view(-1, self.enc_dim)
        h_list = [x_enc]
        last_layer = self.num_gnn_layers - 1
        for layer in range(self.num_gnn_layers):
            h = self.model[layer](h_list[layer], edge_index)
            h = self.batch_norms[layer](h)
            # ReLU on every layer except the last one.
            if layer != last_layer:
                h = F.relu(h)
            h = F.dropout(h, self.drop_ratio, training=self.training)
            h_list.append(h)
        if self.JK == "concat":
            features = torch.cat(h_list, dim=1)
        else:
            features = h_list[-1]
        return self.node_mlp(features)